{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "78ccf609-eff0-46a9-856c-9aab115d06ab",
   "metadata": {},
   "source": [
    "# Stable Diffusion img2img Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d58a00e5-dbc7-4333-8a9d-572a755b6a1c",
   "metadata": {},
   "source": [
    "## Install libraries "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "31fea7d0-4383-4260-85d1-e16b6e8d3f44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q transformers diffusers accelerate torch==1.13.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b8a74b62-32ba-4c51-8167-91048ca17e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q \"ipywidgets>=7,<8\" ftfy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44424596-0d8e-40a9-90e8-d6738d27f3f8",
   "metadata": {},
   "source": [
    "## Authenticate with the Hugging Face Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "391cdfed-03ed-4fce-a6b3-0caa3221a6b0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7abdb5470ef14221b330335c481787d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "# Show the Hugging Face login widget so an access token can be entered in the notebook.\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44020274-6802-4f9b-91f8-27225a05a6e2",
   "metadata": {},
   "source": [
    "## Import dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e43ef6b2-cef2-4869-859f-e11fe58fede8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionImg2ImgPipeline, EulerDiscreteScheduler\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "import torch\n",
    "import re"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "164e61ad-feb0-4659-ac20-3f5529b87486",
   "metadata": {},
   "source": [
    "## Slugify prompts (strip punctuation, replace whitespace with hyphens) for use in file names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "6884ee77-78ae-4ece-b108-e634472382de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slugify(text):\n",
    "    text = re.sub(r\"[^\\w\\s]\", \"\", text)\n",
    "    text = re.sub(r\"\\s+\", \"-\", text)\n",
    "    return text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54f52e1a-9335-4379-97ca-b89bb27a211e",
   "metadata": {},
   "source": [
    "## Define model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "661d2bcd-f517-41f8-a1e6-76d0ead9eb0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face Hub repository id; passed to `from_pretrained` when the scheduler and pipeline are created.\n",
    "model_id = \"stabilityai/stable-diffusion-2\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86b6ecdf-049d-4b4f-8eca-4a3f2743c1cd",
   "metadata": {},
   "source": [
    "## Pull down dataset from Roboflow\n",
    "### Add in your API key below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "6c7a3b21-afe7-4655-8946-c84e6df2b8bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading Roboflow workspace...\n",
      "loading Roboflow project...\n",
      "Downloading Dataset Version Zip in Construction-Site-Safety-1 to coco: 100% [27793802 / 27793802] bytes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Extracting Dataset Version Zip to Construction-Site-Safety-1 in coco:: 100%|██████████| 406/406 [00:00<00:00, 6045.21it/s]\n"
     ]
    }
   ],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q roboflow\n",
    "\n",
    "import os\n",
    "\n",
    "from roboflow import Roboflow\n",
    "\n",
    "# Prefer the ROBOFLOW_API_KEY environment variable so the key is not saved in\n",
    "# the notebook; otherwise replace the placeholder below with your key.\n",
    "rf = Roboflow(api_key=os.environ.get(\"ROBOFLOW_API_KEY\", \"YOUR_API_KEY_HERE\"))\n",
    "# Project that provides the source images.\n",
    "project = rf.workspace(\"roboflow-universe-projects\").project(\"construction-site-safety\")\n",
    "# Download version 1 of the dataset in COCO format.\n",
    "dataset = project.version(1).download(\"coco\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0880465-2f71-4670-b2b5-aa6cc7418e10",
   "metadata": {},
   "source": [
    "## Create a list of images and put them in proper format/size\n",
    "#### Replace images below with your image locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "902d6640-aff6-4b26-9f46-7c1f363870e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the paths from the folder the dataset was downloaded to instead of a\n",
    "# machine-specific absolute path. Replace the file names with your own; an\n",
    "# absolute path in this list is used as-is.\n",
    "train_dir = Path(dataset.location) / \"train\"\n",
    "image_names = [\n",
    "    \"000246_jpg.rf.803c6bf16e1d86b997796ebb8b4b2152.jpg\",\n",
    "    \"000830_jpg.rf.a21dcfb4aa17f2b4c0d22ba91549b7db.jpg\",\n",
    "    \"004779_jpg.rf.92c537eed971d0111cd63ddf4589d77b.jpg\",\n",
    "]\n",
    "images = [str(train_dir / name) for name in image_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "b35d4c81-4067-4d1c-aa04-69203a43f332",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open every source file, force RGB mode and scale it to TARGET_SIZE.\n",
    "TARGET_SIZE = (768, 768)\n",
    "init_images = [\n",
    "    Image.open(path).convert(\"RGB\").resize(TARGET_SIZE)\n",
    "    for path in images\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09911d6c-c567-4fb0-a0a7-34dfa429b469",
   "metadata": {},
   "source": [
    "## Define prompts with negative prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "9fa99ce8-587d-4fe7-b292-318e9a5fd01f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompts = [\"construction worker in snowy landscape\",\n",
    "           \"construction worker in dark evening\", \n",
    "           \"construction worker in rain storm\"]\n",
    "\n",
    "negative_prompts = [\"blurry, dark photo, blue\",\n",
    "                    \"blurry, dark photo, blue\",\n",
    "                    \"blurry, dark photo, blue\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bbc7df4-1717-468f-8b8b-93e009def5eb",
   "metadata": {},
   "source": [
    "## Create scheduler and pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ab1f0413-ef92-4946-86a3-9ca3e744ad49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "77100cbbc5584445b40ceeb78c74c6c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cheduler_config.json:   0%|          | 0.00/345 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "799a01dcb3544aa7a0e943952837ea19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)ain/model_index.json:   0%|          | 0.00/539 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a2da229e2914a9ab4ecb732cdeae851",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64cb41002753450ca43cf303deb1d4d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)\"pytorch_model.bin\";:   0%|          | 0.00/1.36G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "411c6f8ce1df48ef8955bb5af45a3ce8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)rocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a1b2cd357084ead964066fa1b6885e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)_encoder/config.json:   0%|          | 0.00/633 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ef43b595cfc4cfeb7459a2801c9ca82",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdd0682cb3994514a619d523dedaeb5a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1819c6301e55450ab101dcafeb5be99e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a08f0807c7034399acd455bab3e0fc40",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/824 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a1bc856910e4b9182bfe50f6d7ea37b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)b1b/unet/config.json:   0%|          | 0.00/909 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7581fd9a5b5c43fa90213686ed2ffe61",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)_pytorch_model.bin\";:   0%|          | 0.00/335M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ae778ac40d441faa876e13494674f15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)3b1b/vae/config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "921258a74f124343b8f46855921075bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)_pytorch_model.bin\";:   0%|          | 0.00/3.46G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Half precision is only requested on the GPU; on CPU fall back to float32,\n",
    "# where float16 inference is not reliably supported.\n",
    "torch_dtype = torch.float16 if device == \"cuda\" else torch.float32\n",
    "# Use the Euler scheduler here instead of default\n",
    "scheduler = EulerDiscreteScheduler.from_pretrained(\n",
    "    model_id, subfolder=\"scheduler\")\n",
    "pipe = StableDiffusionImg2ImgPipeline.from_pretrained(\n",
    "    model_id, scheduler=scheduler, torch_dtype=torch_dtype)\n",
    "pipe = pipe.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb3c8df3-19df-4b43-b8f0-83261cec45d0",
   "metadata": {},
   "source": [
    "## Create dir for storing generated images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "e16208b3-fbfe-42af-bb7d-f9db4203edba",
   "metadata": {},
   "outputs": [],
   "source": [
    "DIR_NAME=\"./images/\"\n",
    "dirpath = Path(DIR_NAME)\n",
    "# create parent dir if doesn't exist\n",
    "dirpath.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbe82fb6-809f-432d-adf4-a628499c2a21",
   "metadata": {},
   "source": [
    "## Define pipeline parameters and generate images based on prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "761b1d74-2b8b-494b-9b3a-e988a44ddc4f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87a9eedfa2c34792ac944080149c6e23",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "steps = 20   # passed as num_inference_steps\n",
    "scale = 9    # passed as guidance_scale\n",
    "num_images_per_prompt = 1\n",
    "# A new seed is drawn on every run; print it so a given result can be\n",
    "# reproduced later by assigning that value to `seed` instead.\n",
    "seed = torch.randint(0, 1000000, (1,)).item()\n",
    "print(f\"seed: {seed}\")\n",
    "generator = torch.Generator(device=device).manual_seed(seed)\n",
    "output = pipe(prompts, negative_prompt=negative_prompts, image=init_images, num_inference_steps=steps,\n",
    "             guidance_scale=scale, num_images_per_prompt=num_images_per_prompt, generator=generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b0a5a37-5b7a-4ce2-8c7b-bc6312c255b2",
   "metadata": {},
   "source": [
    "## Iterate over the generated images and save them to the images dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "eba9aaa1-de1e-4a9c-89cf-2ba1ebcafc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results for one prompt are returned next to each other (p0, p0, p1, p1, ...),\n",
    "# so image `idx` belongs to prompt `idx // num_images_per_prompt`. Tiling the\n",
    "# prompt list (prompts * n) would mislabel files once n > 1.\n",
    "for idx, image in enumerate(output.images):\n",
    "    prompt = prompts[idx // num_images_per_prompt]\n",
    "    # The running index keeps names unique when a prompt yields several images.\n",
    "    image_name = f'{slugify(prompt)}-{idx}.png'\n",
    "    image_path = dirpath / image_name\n",
    "    image.save(image_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91147dd2-4633-4433-93a4-9f0b744402d9",
   "metadata": {},
   "source": [
    "## Upload Images to Roboflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "5066842d-ffe6-4cdc-8890-a7e74c8b12fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "HOME = os.getcwd()\n",
    "image_dir = os.path.join(HOME, \"images\", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "78cbc5c0-763f-4385-8791-121920ee6c1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/ec2-user/SageMaker/images/'"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the resolved upload directory as the cell output.\n",
    "image_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a8dff51-1418-4956-86f1-cca9d12d3b10",
   "metadata": {},
   "source": [
    "### Add in your API key and project name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "b5991293-c830-49b7-bf64-abb249f40e5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading Roboflow workspace...\n",
      "loading Roboflow project...\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/construction-worker-in-dark-evening-1.png ***\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/construction-worker-in-snowy-landscape-0.png ***\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/blue-eyes-and-a-ponytail-2.png ***\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/sunglasses-and-a-warm-hat-0.png ***\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/construction-worker-in-rain-storm-2.png ***\n",
      "*** Processing image [6] - /home/ec2-user/SageMaker/images/eye-glasses-and-a-baseball-cap-1.png ***\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os\n",
    "\n",
    "## DEFINITIONS\n",
    "# glob params\n",
    "file_extension_type = \".png\"\n",
    "\n",
    "## INIT\n",
    "# roboflow pip params\n",
    "# Prefer the ROBOFLOW_API_KEY environment variable so the key is not saved in\n",
    "# the notebook; otherwise replace the placeholder below with your key.\n",
    "rf = Roboflow(api_key=os.environ.get(\"ROBOFLOW_API_KEY\", \"YOUR_API_KEY\"))\n",
    "upload_project = rf.workspace().project(\"YOUR_PROJECT\")\n",
    "\n",
    "## MAIN\n",
    "# glob images; os.path.join avoids a doubled separator because image_dir\n",
    "# already ends with one, and sorting makes the upload order deterministic\n",
    "image_glob = sorted(glob.glob(os.path.join(image_dir, \"*\" + file_extension_type)))\n",
    "\n",
    "# perform upload\n",
    "total = len(image_glob)\n",
    "for position, image in enumerate(image_glob, start=1):\n",
    "    # Log before uploading so a failing file can be identified; the counter\n",
    "    # shows the position in the batch, not just the batch size.\n",
    "    print(f\"*** Processing image [{position}/{total}] - {image} ***\")\n",
    "    upload_project.upload(image, num_retry_uploads=3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
